#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2015 Pivotal Software, Inc. All Rights Reserved
#
# This software contains the intellectual property of Pivotal Software, Inc.
# or is licensed to Pivotal Software, Inc. from third parties. Use of this
# software and the intellectual property contained therein is expressly
# limited to the terms and conditions of the License Agreement under which
# it is provided by or on behalf of Pivotal Software, Inc.


# import mainUtils FIRST to get python version check
from gppylib.mainUtils import *
from optparse import OptionParser
from Queue import Queue, Empty
import os
import re
import shutil
import fnmatch
import tempfile
import time
from datetime import datetime, timedelta
import pipes  # for shell-quoting, pipes.quote()
import fcntl

try:
    from gppylib import gplog, pgconf, userinput
    from gppylib.commands.base import Command, WorkerPool, Worker
    from gppylib.operations import Operation
    from gppylib.gpversion import GpVersion
    from gppylib.db import dbconn
    from gppylib.operations.unix import CheckDir, CheckFile, MakeDir
    from pygresql import pg

except ImportError, e:
    sys.exit('Cannot import modules.  Please check that you have sourced greenplum_path.sh.  Detail: ' + str(e))

# Name of this utility.
EXECNAME = 'analyzedb'
# Sub-directory of the master data directory under which per-database state
# folders are kept (see AnalyzeDb._preprocess_options).
STATEFILE_DIR = 'db_analyze'
logger = gplog.get_default_logger()
# NOTE(review): presumably the name of a lock file guarding state-file writes;
# its use is not visible in this part of the file.
WRITE_LOCK_FILE_NAME = "write_lock_semaphore"
# GUC settings prepended to every non-root 'analyze' statement (see AnalyzeDb.run_analyze).
ANALYZE_GUCS = "set optimizer_analyze_enable_merge_of_leaf_stats=off; "
# Statement templates; %s is the target produced by _get_tablename_with_cols().
ANALYZE_SQL = """analyze %s"""
ANALYZE_ROOT_SQL = """analyze rootpartition %s"""
# NOTE(review): report-retention settings; the code consuming them is not
# visible in this part of the file.
REPORTS_ARE_STALE_AFTER_N_DAYS = 8
NUM_REPORTS_TO_SAVE = 3

# Sub-query embedded (as a derived table) in several templates below.
# Produces one row per non-template partition rule with:
#   schemaname / tablename                     - the partitioned (root) table
#   partitionschemaname / partitiontablename   - the child relation of the rule
#   parentpartitiontablename                   - the child relation of the rule's
#                                                parent rule, NULL when there is none
PG_PARTITIONS_SURROGATE = """
SELECT n.nspname AS schemaname, cl.relname AS tablename, n2.nspname AS partitionschemaname, cl2.relname AS partitiontablename, cl3.relname AS parentpartitiontablename
FROM pg_namespace n, pg_namespace n2, pg_class cl, pg_class cl2, pg_partition pp,
pg_partition_rule pr1
LEFT JOIN pg_partition_rule pr2 ON pr1.parparentrule = pr2.oid
LEFT JOIN pg_class cl3 ON pr2.parchildrelid = cl3.oid
WHERE pp.paristemplate = false AND pp.parrelid = cl.oid AND pr1.paroid = pp.oid AND cl2.oid = pr1.parchildrelid AND cl.relnamespace = n.oid AND cl2.relnamespace = n2.oid
"""

# All ordinary tables (relkind 'r') in 'public', 'pg_catalog' or a namespace
# with oid >= 16384, excluding external tables, partitioned root tables and
# tables that act as the parent of another partition (mid-level partitions).
GET_ALL_DATA_TABLES_SQL = """
select n.nspname as schemaname, c.relname as tablename from pg_class c, pg_namespace n where
c.relnamespace = n.oid and c.relkind='r'::char and (c.relnamespace >= 16384 or n.nspname = 'public' or n.nspname = 'pg_catalog') and c.oid not in (select reloid from pg_exttable)
EXCEPT
select distinct schemaname, tablename from (%s) AS pps1
EXCEPT
select distinct partitionschemaname, parentpartitiontablename from (%s) AS pps2 where parentpartitiontablename is not NULL
""" % (PG_PARTITIONS_SURROGATE, PG_PARTITIONS_SURROGATE)

# Filters a list of oids (the %s placeholder) down to ordinary, non-external
# tables in 'public', 'pg_catalog' or a namespace with oid >= 16384.
GET_VALID_DATA_TABLES_SQL = """
select n.nspname as schemaname, c.relname as tablename from pg_class c, pg_namespace n where
c.relnamespace = n.oid and c.oid in (%s) and c.relkind='r'::char and (c.relnamespace >= 16384 or n.nspname = 'public' or n.nspname = 'pg_catalog') and c.oid not in (select reloid from pg_exttable)
"""

# For the given oids (%s) that have a pg_appendonly entry, returns
# (oid, schemaname, tablename, tupletable) where tupletable is the relname of
# the relation referenced by pg_appendonly.segrelid.
GET_REQUESTED_AO_DATA_TABLE_INFO_SQL = """
    SELECT ALL_DATA_TABLES.oid, ALL_DATA_TABLES.schemaname, ALL_DATA_TABLES.tablename, OUTER_PG_CLASS.relname as tupletable FROM
    (
    select c.oid as oid, n.nspname as schemaname, c.relname as tablename from pg_class c, pg_namespace n where
    c.relnamespace = n.oid and c.oid in (%s)
    ) as ALL_DATA_TABLES, pg_appendonly, pg_class OUTER_PG_CLASS
    WHERE ALL_DATA_TABLES.oid = pg_appendonly.relid
    AND OUTER_PG_CLASS.oid = pg_appendonly.segrelid
"""

# For the given oids (%s), returns the CREATE / ALTER / TRUNCATE entries of
# pg_stat_last_operation together with the schema and table name, ordered by
# object id and action name.
GET_REQUESTED_LAST_OP_INFO_SQL = """
    SELECT PGN.nspname, PGC.relname, objid, staactionname, stasubtype, statime FROM pg_stat_last_operation, pg_class PGC, pg_namespace PGN
    WHERE objid = PGC.oid
    AND objid in (%s)
    AND PGC.relnamespace = PGN.oid
    AND staactionname IN ('CREATE', 'ALTER', 'TRUNCATE')
    ORDER BY objid, staactionname
"""

# Same selection as GET_ALL_DATA_TABLES_SQL, restricted to one schema. After
# the module-level substitution a single %s (the schema name) remains.
GET_ALL_DATA_TABLES_IN_SCHEMA_SQL = """
select n.nspname as schemaname, c.relname as tablename from pg_class c, pg_namespace n where
c.relnamespace = n.oid and c.relkind='r'::char and (c.relnamespace >= 16384 or n.nspname = 'public' or n.nspname = 'pg_catalog') and c.oid not in (select reloid from pg_exttable)
and n.nspname = '%s'
EXCEPT
select distinct schemaname, tablename from (%s) AS pps1
EXCEPT
select distinct partitionschemaname, parentpartitiontablename from (%s) AS pps2 where parentpartitiontablename is not NULL
""" % ('%s', PG_PARTITIONS_SURROGATE, PG_PARTITIONS_SURROGATE)

# Partitions of one root table that are not the parent of another partition.
# After the module-level substitution four %s placeholders remain, in the
# order: schema, table, schema, table (of the root table).
GET_LEAF_PARTITIONS_SQL = """
select partitionschemaname, partitiontablename from (%s) AS pps1 where schemaname = '%s' and tablename = '%s'
EXCEPT
select distinct partitionschemaname, parentpartitiontablename from (%s) AS pps2 where parentpartitiontablename is not NULL
and schemaname = '%s' and tablename = '%s'
""" % (PG_PARTITIONS_SURROGATE, '%s', '%s', PG_PARTITIONS_SURROGATE, '%s', '%s')

# Every relation that is the parent of some partition (mid-level partitions),
# database wide. No placeholders remain after the module-level substitution.
GET_MID_LEVEL_PARTITIONS_SQL = """
select distinct partitionschemaname, parentpartitiontablename from (%s) AS pps1 where parentpartitiontablename is not NULL
""" % PG_PARTITIONS_SURROGATE

# Among the given oids (the remaining %s), the ordinary tables that have no
# pg_appendonly entry and are not external, root or mid-level partition tables.
GET_REQUESTED_NON_AO_TABLES_SQL = """
select n.nspname as schemaname, c.relname as tablename from pg_class c, pg_namespace n where
c.relnamespace = n.oid and c.relkind='r'::char and (c.relnamespace >= 16384 or n.nspname = 'public' or n.nspname = 'pg_catalog')
and c.oid not in (select relid from pg_appendonly) and c.oid in (%s) and c.oid not in (select reloid from pg_exttable)
EXCEPT
select distinct schemaname, tablename from (%s) AS pps1
EXCEPT
select distinct partitionschemaname, parentpartitiontablename from (%s) AS pps2 where parentpartitiontablename is not NULL
""" % ('%s', PG_PARTITIONS_SURROGATE, PG_PARTITIONS_SURROGATE)

# Names of the user columns (attnum > 0, not dropped) of relation %s.
GET_COLUMN_NAMES_SQL = """
SELECT attname FROM pg_attribute WHERE attrelid = %s AND attnum > 0 AND NOT attisdropped
"""

# Names of the schemas that contain at least one relation with
# relpersistence 't'; one row per such relation, so names may repeat.
GET_SCHEMA_WITH_TEMP_TABLE_SQL = """
SELECT n.nspname
FROM pg_class c, pg_namespace n
WHERE c.relnamespace = n.oid AND c.relpersistence = 't';
"""

# User columns of relation (first %s) whose names are NOT in the list (second %s).
GET_INCLUDED_COLUMNS_FROM_EXCLUDE_SQL = """
SELECT attname FROM pg_attribute WHERE attrelid = %s AND attname NOT IN (%s) AND attnum > 0 AND NOT attisdropped
"""

# Number of user columns of relation (first %s) whose names are in the list (second %s).
VALIDATE_COLUMN_NAMES_SQL = """
SELECT count(*) FROM pg_attribute WHERE attrelid = %s AND attname IN (%s) AND attnum > 0 AND NOT attisdropped
"""

# Resolves each name in a VALUES list (%s) through a ::regclass cast and
# returns (schema, relname, original name) for it.
VALIDATE_TABLE_NAMES_SQL = """
SELECT n.nspname, c.relname, orig_name
FROM pg_class c
INNER JOIN pg_namespace n ON c.relnamespace = n.oid
INNER JOIN (VALUES %s) o(orig_name) ON c.oid = o.orig_name::regclass
"""

# For the partition oids in %s, returns (schema, partition name, schema, root name).
# NOTE(review): both schema columns are taken from the root table's namespace
# (alias n) - confirm this is intended for partitions living in another schema.
GET_LEAF_ROOT_MAPPING_SQL = """
SELECT n.nspname, c2.relname, n.nspname, c.relname from pg_class c, pg_class c2, pg_namespace n, pg_partition pp, pg_partition_rule ppr
WHERE ppr.parchildrelid in (%s) AND ppr.paroid = pp.oid AND pp.parrelid = c.oid AND c.relnamespace = n.oid AND ppr.parchildrelid = c2.oid;
"""

# 'schema.table' strings of all partitioned (root) tables that have a
# non-template pg_partition entry.
GET_ALL_ROOT_PARTITION_TABLES_SQL = """
SELECT distinct n.nspname || '.' || c.relname from pg_class c, pg_namespace n, pg_partition pp
WHERE pp.parrelid = c.oid AND c.relnamespace = n.oid AND pp.paristemplate = false
"""

# (schema, table) of the given oids (%s), ordered by oid, highest first.
ORDER_CANDIDATES_BY_OID_SQL = """
SELECT schemaname, tablename FROM
(SELECT c.oid as tableoid, n.nspname as schemaname, c.relname as tablename FROM pg_class c, pg_namespace n where c.relnamespace=n.oid and c.oid in (%s)) AS foo
ORDER BY tableoid DESC;
"""

def validate_schema_exists(pg_port, dbname, schema):
    """
    Check that a namespace named `schema` exists in database `dbname`.

    Opens a dedicated connection on port `pg_port`, counts the matching
    pg_namespace rows and always closes the connection again.

    Raises ExceptionNoStackTraceNeeded when no such schema is found.
    """
    connection = None
    try:
        connection = dbconn.connect(dbconn.DbURL(port=pg_port, dbname=dbname))
        lookup_sql = "select count(*) from pg_namespace where nspname='%s';" % pg.escape_string(schema)
        matches = dbconn.execSQLForSingleton(connection, lookup_sql)
        if matches == 0:
            raise ExceptionNoStackTraceNeeded("Schema %s does not exist in database %s." % (schema, dbname))
    finally:
        if connection:
            connection.close()

def execute_sql(query, master_port, dbname):
    """
    Run `query` against database `dbname` on port `master_port` using a
    dedicated connection and return all result rows (cursor.fetchall()).

    The connection is closed before returning, whether or not the query
    succeeded, so repeated calls do not accumulate open connections.
    """
    conn = dbconn.connect(dbconn.DbURL(port=master_port, dbname=dbname))
    try:
        # Rows must be fetched while the connection is still open.
        return dbconn.execSQL(conn, query).fetchall()
    finally:
        conn.close()

def get_lines_from_file(fname):
    """
    Return the lines of file `fname` as a list of strings, with leading and
    trailing newline characters removed from each line (other whitespace is
    kept).
    """
    with open(fname) as fd:
        return [raw_line.strip('\n') for raw_line in fd]

def compare_dict(last_dict, curr_dict):
    """
    Return the set of keys of `curr_dict` that are either missing from
    `last_dict` or mapped to a different value there. Keys that only exist
    in `last_dict` are not reported.
    """
    return set(
        key for key, value in curr_dict.items()
        if key not in last_dict or value != last_dict[key]
    )

def write_lines_to_file(filename, lines):
    """
    Overwrite `filename` with `lines`, one entry per line. Newline characters
    at either end of an entry are removed and exactly one newline is appended.
    """
    with open(filename, 'w') as fp:
        fp.writelines(entry.strip('\n') + "\n" for entry in lines)

def verify_lines_in_file(fname, expected):
    """
    Read `fname` back with get_lines_from_file() and raise an Exception,
    listing both versions, when its lines differ from the list `expected`.
    """
    actual = get_lines_from_file(fname)
    if actual == expected:
        return

    raise Exception("After writing file '%s' contents not as expected.\nLines read from file: %s\nLines expected from file: %s\n" % (fname, actual, expected))

def get_partition_state_tuples(pg_port, dbname, catalog_schema, partition_info):
    """
    Reads the partition state for an AO or AOCS relation, which is the sum of
    the modification counters over all ao segment files.
    The sum might be an invalid number even when the relation contains tuples.
    The reason is that the master aoseg info tuple for segno 0 is not there
    after the CTAS. Vacuum will correct this missing state on the master.

    Thus, partition state returns 0 also when the aoseg relation is empty.
    Why is that correct?
    A table that has a modcount of 0 can only be there iff the last operation
    was a special operation (TRUNCATE, CREATE, ALTER TABLE). Every DML operation
    will increase the modcount by 1. Therefore it is safe to assume that two
    relations with modcount 0 with the same last special operation do not have a
    logical change in them.

    :param pg_port: port used to build the connection URL
    :param dbname: database to connect to
    :param catalog_schema: schema that holds the `tupletable` relations
    :param partition_info: iterable of (oid, schemaname, partition_name, tupletable)
    :return: a list of tuples of the format (schemaname, partition_name, modcount),
             where modcount is the whitespace-stripped string produced by to_char().
             Relations whose query raised pg.DatabaseError are left out.
    :raises Exception: from validate_modcount() when a modcount is not acceptable
    """
    partition_list = list()
    dburl = dbconn.DbURL(port=pg_port, dbname=dbname)
    num_sqls = 0
    with dbconn.connect(dburl) as conn:
        for (oid, schemaname, partition_name, tupletable) in partition_info:
            try:
                modcount_sql = "select to_char(coalesce(sum(modcount::bigint), 0), '999999999999999999999') from %s.%s" % (catalog_schema, tupletable)
                modcount = dbconn.execSQLForSingleton(conn, modcount_sql)
            except pg.DatabaseError as e:
                if "does not exist" in str(e):
                    logger.info("Table %s.%s (%s) no longer exists and will not be analyzed", schemaname, partition_name, tupletable)
                else:
                    logger.error(str(e))
                # If there's an exception, the transaction is closed so we need to reconnect
                # NOTE(review): the replacement connection is not the object managed by the
                # enclosing `with`; confirm that it and the superseded connection get closed.
                conn = dbconn.connect(dburl)
            else:
                num_sqls += 1
                if num_sqls == 1000: # The choice of batch size was chosen arbitrarily
                    logger.debug('Completed executing batch of 1000 tuple count SQLs')
                    conn.commit()
                    num_sqls = 0
                # to_char() pads its output, so strip before validating
                if modcount:
                    modcount = modcount.strip()
                validate_modcount(schemaname, partition_name, modcount)
                partition_list.append((schemaname, partition_name, modcount))
    return partition_list

def validate_modcount(schema, tablename, cnt):
    """
    Sanity-check the modification count string `cnt` of table
    `schema`.`tablename`.

    An empty or None `cnt` is accepted as is. Otherwise an Exception is
    raised when `cnt` contains anything but digits or is longer than 15
    characters.
    """
    if cnt:
        if not cnt.isdigit():
            raise Exception("Can not convert modification count for table. Possibly exceeded  backup max tuple count of 1 quadrillion rows per table for: '%s.%s' '%s'" % (schema, tablename, cnt))
        if len(cnt) > 15:
            raise Exception("Exceeded backup max tuple count of 1 quadrillion rows per table for: '%s.%s' '%s'" % (schema, tablename, cnt))

class AnalyzeDb(Operation):
    def __init__(self, options, args):
        """
        Capture the parsed command-line `options`, validate and pre-process
        them, and open the database connection used by the other methods.

        `args` holds leftover positional arguments; they are only reported
        in a warning and otherwise ignored.
        """
        if args:
            logger.warn("Please note that some of the arguments (%s) aren't valid and will be ignored.", args)
        if not options.rootstats:
            logger.warn("The --skip_root_stats option is no longer supported and will be ignored.")

        # Location of the state files. The data directory has to be known
        # before the port lookup, which may read postgresql.conf from it.
        if options.masterDataDirectory is None:
            options.masterDataDirectory = gp.get_masterdatadir()
        self.master_datadir = options.masterDataDirectory
        self.analyze_dir = STATEFILE_DIR
        self.pg_port = self._get_pgport()

        # What to analyze
        self.dbname = options.dbname
        self.schema = options.schema
        self.single_table = options.single_table
        self.config_file = options.config_file
        self.entire_db = False
        self.include_cols = options.include_cols
        self.exclude_cols = options.exclude_cols

        # How to analyze
        self.full_analyze = options.full_analyze
        self.dry_run = options.dry_run
        self.parallel_level = options.parallel_level
        self.orca_rootstats = options.orca_rootstats
        self.gen_profile_only = options.gen_profile_only
        self.analyze_gucs = ANALYZE_GUCS

        # Interaction / logging
        self.silent = options.silent
        self.verbose = options.verbose

        # State-file clean-up modes
        self.clean_last = options.clean_last
        self.clean_all = options.clean_all

        # (schema, table) pairs that were analyzed successfully
        self.success_list = []

        self._validate_options()
        self._preprocess_options()

        target_url = dbconn.DbURL(port=self.pg_port, dbname=self.dbname)
        self.conn = dbconn.connect(target_url, utility=False)

    def _validate_options(self):
        """
        Check the captured options for consistency.

        Raises ProgramArgumentValidationException when -d is missing, when
        mutually exclusive options are combined, when -i/-x is given without
        -t, or when the parallel level is outside 1..10.
        """
        if not self.dbname:
            raise ProgramArgumentValidationException("option -d required. Please see 'analyzedb -?' for usage.")

        if self.clean_all + self.clean_last > 1:
            raise ProgramArgumentValidationException('options --clean_last and --clean_all are mutually exclusive')

        target_selectors = (self.schema, self.single_table, self.config_file)
        if sum(selector is not None for selector in target_selectors) > 1:
            raise ProgramArgumentValidationException('options -s, -t and -f are mutually exclusive')

        has_include = self.include_cols is not None
        has_exclude = self.exclude_cols is not None
        if has_include and has_exclude:
            raise ProgramArgumentValidationException('options -i and -x are mutually exclusive')

        if (has_include or has_exclude) and self.single_table is None:
            raise ProgramArgumentValidationException('option -i or -x can only be used together with -t')

        if not 1 <= self.parallel_level <= 10:
            raise ProgramArgumentValidationException('option -p requires a value between 1 and 10')

    def _preprocess_options(self):
        """
        Act on the clean-up options and normalize the remaining ones.

        - --clean_all: after confirmation (unless silent), delete everything
          inside <master_datadir>/<analyze_dir>/<dbname>.
        - --clean_last: after confirmation (unless silent), delete the folder
          of the most recent analyze timestamp.
        - Split the comma separated include/exclude column strings into lists.
        - Set self.entire_db when none of -t, -s, -f was given.
        - Switch the logger to debug level when verbose.

        Raises UserAbortedException when the user declines a deletion.
        """
        if self.clean_all:
            analyze_folder = os.path.join(self.master_datadir, self.analyze_dir, self.dbname)
            if os.path.exists(analyze_folder):
                if not self.silent \
                        and not userinput.ask_yesno(None,
                                                    "\nDeleting all files and folders in %s ?" % analyze_folder,
                                                    'N'):
                    raise UserAbortedException()
                for f in os.listdir(analyze_folder):
                    f_path = os.path.join(analyze_folder, f)
                    if os.path.isdir(f_path):
                        shutil.rmtree(f_path)
                    else:
                        os.remove(f_path)
            else:
                # The folder name was previously not passed, so a literal '%s' was logged.
                logger.warning("Folder %s does not exist. Exiting...", analyze_folder)

        if self.clean_last:
            last_analyze_timestamp = get_latest_analyze_timestamp(self.master_datadir, self.analyze_dir, self.dbname)
            if last_analyze_timestamp is not None:
                analyze_folder = os.path.join(self.master_datadir, self.analyze_dir, self.dbname,
                                              last_analyze_timestamp)
                if os.path.exists(analyze_folder):
                    if not self.silent \
                            and not userinput.ask_yesno(None,
                                                        "\nDeleting folder %s and all files inside?" % analyze_folder,
                                                        'N'):
                        raise UserAbortedException()
                    shutil.rmtree(analyze_folder)
                else:
                    # The folder name was previously not passed, so a literal '%s' was logged.
                    logger.warning("Folder %s does not exist. Exiting...", analyze_folder)
            else:
                logger.warning("No valid state files directories exist. Exiting...")

        if self.include_cols is not None:
            self.include_cols = self.include_cols.strip().split(',')

        if self.exclude_cols is not None:
            self.exclude_cols = self.exclude_cols.strip().split(',')

        if self.single_table is None and self.schema is None and self.config_file is None:
            self.entire_db = True

        if self.verbose:
            logger.setLevel(10)  # 10 is the numeric value of logging.DEBUG

    def execute(self):
        """
        Run the analyzedb workflow.

        1. Return immediately when only a state clean-up was requested.
        2. Collect the requested tables and columns.
        3. Determine which of them are dirty by comparing the current AO
           modcount / last-operation state with the state read from the
           previous run; heap tables are always regarded as dirty.
        4. Build, log and (unless silent) confirm the list of targets.
        5. Run the analyze statements (or, with gen_profile_only, just mark
           every target as successful) and write the report.

        The connection held in self.conn is closed on every exit path.
        Any exception is logged and re-raised. Returns 0 otherwise.
        """
        try:
            if self.clean_all or self.clean_last:
                return 0

            # The input_col_dict keeps track of the requested columns to analyze.
            # key: table name (e.g. 'public.foo')
            # value: a set of requested column names (e.g. set(['col1','col2']), or '-1' indicating all columns
            input_col_dict = {}

            # parse input and update input column dictionary
            input_tables = self._get_input_tables(input_col_dict)  # ['public.foo', 'public.bar', ...]
            input_tables_set = set(input_tables)

            if not input_tables_set:
                logger.warning("There are no tables or partitions to be analyzed. Exiting...")
                return 0

            logger.info("Checking for tables with stale stats...")
            # get all heap tables in the requested tables. all heap tables are regarded as dirty
            heap_partitions = get_heap_tables_set(self.conn, input_tables_set)  # set((schema1,table1), ...])

            # get the current state of the requested tables
            # curr_ao_state contains the number of DML commands that have been executed on an AO table
            # curr_last_op contains the timestamp of the last DDL command (CREATE, ALTER, TRUNCATE, DROP) of an AO table
            curr_ao_state = self._get_ao_state(input_tables_set)  # (('schema', 'table', modcount), ...)
            curr_last_op = self._get_lastop_state(
                input_tables_set)  # e.g. ['public,ao_tab,67468,ALTER,ADD COLUMN,2014-10-15 14:49:27.658777-07', ...]

            # the timestamp of the previous run is not needed here
            _, prev_ao_state, prev_col_dict, prev_last_op = self.read_last_analyzedb_output()

            # compare two states to get dirty tables
            dirty_partitions = self._get_dirty_data_tables(heap_partitions, curr_ao_state, curr_last_op, prev_ao_state,
                                                           prev_last_op)

            candidates = set()  # set(['public.foo', 'public.bar', ...])

            if self.full_analyze:  # full analyze does not support column-level increments and will invalidate previous column-level state
                candidates = input_tables_set
            else:  # incremental
                for schema_table in input_tables_set:
                    if schema_table in dirty_partitions:  # for dirty partitions, we invalidate all previous column-level state
                        candidates.add(schema_table)
                    else:
                        # figure out which columns need to be analyzed
                        self._update_input_col_dict_with_column_increments(schema_table, input_col_dict, prev_col_dict)
                        if len(input_col_dict[schema_table]) > 0:
                            candidates.add(schema_table)

            candidates = self._get_valid_candidates(candidates)
            if len(candidates) < 1:
                logger.warning("There are no tables or partitions to be analyzed. Exiting...")
                return 0

            root_partition_col_dict = {}
            # root_partition_col_dict contains the mapping between the root partitions
            # and its corresponding columns to be analyzed
            # key: name of the root partition whose stats needs to be refreshed
            # value: a set of column names to be analyzed, or '-1' meaning all columns of that table
            # this can be suppressed with a flag if ORCA is not used
            # (performance improvement when there are many partitions and columns)
            if self.orca_rootstats:
                root_partition_col_dict = self._get_root_partition_col_dict(candidates, input_col_dict)

            ordered_candidates = self._get_ordered_candidates(candidates, root_partition_col_dict)
            target_list = []
            logger.info("---------------------------------------------------")
            logger.info("Tables or partitions to be analyzed")
            logger.info("---------------------------------------------------")
            for can in ordered_candidates:
                can_schema = can[0]
                can_table = can[1]
                if can in candidates:
                    target = self._get_tablename_with_cols(can_schema, can_table, input_col_dict)
                else:  # can in root_partition_col_dict
                    target = self._get_tablename_with_cols(can_schema, can_table, root_partition_col_dict)
                logger.info(target)
                target_list.append(target)
            logger.info("---------------------------------------------------")

            if self.dry_run:
                return 0

            if not self.silent and not userinput.ask_yesno(None, "\nContinue with Analyze", 'N'):
                raise UserAbortedException()
            try:
                if self.gen_profile_only:
                    # no statement is run; every target is recorded as analyzed
                    for can in ordered_candidates:
                        self.success_list.append((can[0], can[1]))
                else:
                    self.run_analyze(logger, ordered_candidates, candidates, input_col_dict, root_partition_col_dict)

            finally:
                # the report is written even when run_analyze() raised
                self._write_report(curr_ao_state, curr_last_op, heap_partitions, input_col_dict,
                                   root_partition_col_dict, dirty_partitions, target_list)
                logger.info("Done.")
        except Exception as ex:
            logger.exception(ex)
            raise

        finally:
            if self.conn:
                self.conn.close()

        return 0

    # We run 'analyze <table>' and 'analyze rootpartition <root>' commands in
    # separate batches, as 'analyze rootpartition' commands will go through
    # different codepaths depending on whether all leaf partitions have been
    # analyzed. If all leaves have been analyzed, 'analyze rootpartition'
    # simply merges the HLL statistics. If any leaves have not been analyzed, a
    # sample is taken, which can take longer. Thus we run all 'analyze'
    # statements before 'analyze rootpartition' statements
    def run_analyze(self, logger, ordered_candidates, candidates, input_col_dict, root_partition_col_dict):
        """
        Build one psql command per entry of `ordered_candidates` and execute
        them through two worker pools: first all plain 'analyze' commands,
        then all 'analyze rootpartition' commands.

        :param logger: logger used for the start and summary messages
        :param ordered_candidates: (schema, table) entries in execution order
        :param candidates: entries to analyze with ANALYZE_SQL; every other
                           entry is treated as a root partition
        :param input_col_dict: column selection for the entries in `candidates`
        :param root_partition_col_dict: column selection for root partitions

        Successful commands are recorded in self.success_list by
        run_analyze_statements(). Any exception halts the pools and is re-raised;
        the elapsed-time summary is logged in every case.
        """
        logger.info("Starting analyze with %d workers..." % self.parallel_level)
        non_root_pool = AnalyzeWorkerPool(numWorkers=self.parallel_level)
        root_pool = AnalyzeWorkerPool(numWorkers=self.parallel_level)

        for can in ordered_candidates:
            can_schema, can_table = can[0], can[1]
            if can in candidates:
                target = self._get_tablename_with_cols(can_schema, can_table, input_col_dict)
                cmd = create_psql_command(self.dbname, self.analyze_gucs + ANALYZE_SQL % target)
                non_root_pool.addCommand(cmd)

            else:  # can in root_partition_col_dict
                target = self._get_tablename_with_cols(can_schema, can_table, root_partition_col_dict)
                cmd = create_psql_command(self.dbname, ANALYZE_ROOT_SQL % target)
                root_pool.addCommand(cmd)

            # Also stash the name of the target table in the object, so that it can be extracted
            # from it later.
            cmd.target_schema = can_schema
            cmd.target_table = can_table

        # wait_count counts the commands whose completion has not been consumed yet
        wait_count = num_tables = len(ordered_candidates)
        # presumably the number of commands added to root_pool - TODO confirm in AnalyzeWorkerPool
        num_roots = root_pool.assigned
        start_time = time.time()
        non_root_pool.run()
        try:
            # consume completions until only the root commands are outstanding
            while wait_count > num_roots:
                self.run_analyze_statements(non_root_pool, wait_count, num_tables)
                wait_count -= 1
            non_root_pool.join()
            non_root_pool.haltWork()
            root_pool.run()
            while wait_count > 0:
                self.run_analyze_statements(root_pool, wait_count, num_tables)
                wait_count -= 1
            root_pool.join()
            root_pool.haltWork()
        except:
            # bare except on purpose: the pools are halted for any kind of
            # interruption and the exception is re-raised below
            non_root_pool.haltWork()
            non_root_pool.joinWorkers()
            # NOTE(review): root_pool is only halted when the non-root completed
            # queue is empty; confirm this condition matches the intent.
            if non_root_pool.completed_queue.qsize() == 0:
                root_pool.haltWork()
                root_pool.joinWorkers()
            raise

        finally:
            end_time = time.time()
            logger.info(
                        "Total elapsed time: %d seconds. Analyzed %d out of %d table(s) or partition(s) successfully."
                        % (int(end_time - start_time), len(self.success_list), len(ordered_candidates)))

    def run_analyze_statements(self, pool, wait_count, total_to_analyze):
        """
        Take one finished command from `pool.completed_queue`, record its
        (schema, table) in self.success_list when it succeeded, and log a
        progress line whenever `wait_count` is a multiple of 10.
        """
        finished = pool.completed_queue.get()
        if finished.was_successful():
            self.success_list.append((finished.target_schema, finished.target_table))

        if wait_count % 10 != 0:
            return
        logger.info("progress status: completed %d out of %d tables or partitions" %
                    (len(self.success_list), total_to_analyze))

    def read_last_analyzedb_output(self):
        """
        Load the state recorded by the most recent analyzedb run.

        Returns a 4-tuple:
        (last_analyze_timestamp, prev_ao_state, prev_col_dict, prev_last_op)
        """
        state_location = (self.master_datadir, self.analyze_dir, self.dbname)
        last_analyze_timestamp = get_latest_analyze_timestamp(*state_location)
        prev_ao_state = get_prev_ao_state(last_analyze_timestamp, *state_location)
        prev_col_dict = get_prev_col_state(last_analyze_timestamp, *state_location)
        prev_last_op = get_prev_last_op(last_analyze_timestamp, *state_location)
        return last_analyze_timestamp, prev_ao_state, prev_col_dict, prev_last_op

    def _get_pgport(self):
        """
        Return the value of the PGPORT environment variable when it is set
        and non-empty; otherwise fall back to the port configured in the
        master data directory.
        """
        port_from_env = os.getenv('PGPORT')
        if port_from_env:
            return port_from_env
        return self._get_master_port(self.master_datadir)

    def _get_master_port(self, datadir):
        """
        Return the integer 'port' setting read from `datadir`/postgresql.conf.
        """
        logger.debug("Obtaining master's port from master data directory")
        conf_path = datadir + "/postgresql.conf"
        return pgconf.readfile(conf_path).int('port')

    def _get_input_tables(self, col_dict):
        """
        Depending on the way this program was invoked, gather all the requested tables to be analyzed.
        At the same time, parse the requested columns and populate the col_dict.
        If a requested table is partitioned, expand all the leaf partitions.

        :param col_dict: dictionary that is filled in place; keyed by table,
                         the value is the requested column selection
        :return: the keys of col_dict after it has been populated
        :raises ExceptionNoStackTraceNeeded: when the -t table name carries no schema
        """
        logger.info("Getting and verifying input tables...")
        if self.single_table:

            # Check that the table name given on the command line is schema-qualified.
            # XXX: Nowadays, the code should handle that just fine. Should we lift this limitation?
            # XXX: This is a fairly weak test anyway: it will be fooled by double-quoted table name
            # with a dot in it.
            if '.' not in self.single_table:
                raise ExceptionNoStackTraceNeeded("No schema name supplied for table %s" % self.single_table)

            (schema, table) = validate_tables(self.conn, [self.single_table])[self.single_table]
            # for single table, we always try to expand it to avoid getting all root partitions in the database
            self._parse_column(col_dict, self.single_table, schema, table, self.include_cols, self.exclude_cols, True)

        elif self.schema:  # all tables in a schema
            validate_schema_exists(self.pg_port, self.dbname, self.schema)
            logger.debug("getting all tables in the schema...")
            all_schema_tables = run_sql(self.conn, GET_ALL_DATA_TABLES_IN_SCHEMA_SQL % self.schema)
            # key each table by its (schema, table) pair and request all of its columns
            for schema_table in all_schema_tables:
                col_dict[(schema_table[0], schema_table[1])] = set(['-1'])

        elif self.config_file:
            tablenames = parse_tables_from_file(self.conn, self.config_file)
            canonical_tables = validate_tables(self.conn, tablenames)
            all_root_partitions = run_sql(self.conn, GET_ALL_ROOT_PARTITION_TABLES_SQL)
            # 'with' makes sure the config file handle is released again
            with open(self.config_file, 'rU') as cfg_file:
                for line in cfg_file:
                    # XXX: The file format does not allow listing tables with spaces in the name,
                    # even when quoted
                    toks = line.strip().split()
                    if not toks:
                        # blank line: nothing to parse (toks[0] would raise IndexError)
                        continue
                    orig_table = toks[0]
                    (schema, table) = canonical_tables[orig_table]
                    included_cols = self._get_include_or_exclude_cols(toks, '-i')
                    excluded_cols = self._get_include_or_exclude_cols(toks, '-x')
                    self._parse_column(col_dict, orig_table, schema, table, included_cols, excluded_cols,
                                       [orig_table] in all_root_partitions)

        else:  # all user tables in database
            alltables = run_sql(self.conn, GET_ALL_DATA_TABLES_SQL)
            for schema_table in alltables:
                col_dict[(schema_table[0], schema_table[1])] = set(['-1'])

        return col_dict.keys()

    def _get_include_or_exclude_cols(self, line_tokens, option_str):
        """
        Extract the column list that follows an option flag on a config file line.

        :param line_tokens: tokenized line of the config file, e.g. ['<schema>.<table>', '-i', '<col1>,<col2>,...']
        :param option_str: the flag to look for, e.g. '-i' or '-x'
        :return: list of column names following the first occurrence of the flag,
                 or None when the flag is not present on the line
        :raises Exception: when the flag is the last token on the line
        """
        if option_str not in line_tokens:
            return None

        value_pos = line_tokens.index(option_str) + 1
        if value_pos >= len(line_tokens):
            raise Exception("No %s columns specified." % option_str)

        return line_tokens[value_pos].split(',')

    def _get_valid_candidates(self, candidates):
        """
        Filter the candidate (schema, table) pairs:
        1. drop candidates whose schema contains a temporary relation
        2. reject schema or table names containing a newline, ',' or ':'
        3. skip mid-level partitions (a warning is logged for each)
        4. when invoked with -f or -t, keep only the candidates returned by
           GET_VALID_DATA_TABLES_SQL (ordinary, non-external tables)

        :param candidates: iterable of (schema, table) pairs
        :return: set of (schema, table) pairs that passed all checks
        :raises Exception: when a name contains an invalid character
        """
        # A set gives constant-time membership tests in the loop below
        mid_level_partitions = set(
            (schema_tbl[0], schema_tbl[1])
            for schema_tbl in run_sql(self.conn, GET_MID_LEVEL_PARTITIONS_SQL))

        temp_schema_set = set(
            row[0] for row in run_sql(self.conn, GET_SCHEMA_WITH_TEMP_TABLE_SQL))

        ret = set()
        for can in candidates:
            schema = can[0]
            table = can[1]
            if schema in temp_schema_set:
                continue
            if '\n' in schema or ',' in schema or ':' in schema:
                raise Exception('Schema name has an invalid character "\\n", ":", "," : "%s"' % schema)
            if '\n' in table or ',' in table or ':' in table:
                raise Exception('Table name has an invalid character "\\n", ":", "," : "%s"' % table)
            if can in mid_level_partitions:
                logger.warning("Skipping mid-level partition %s.%s" % (schema, table))
            else:
                ret.add(can)

        if self.config_file is not None or self.single_table is not None:
            valid_tables = set()
            if len(ret) > 0:
                oid_str = get_oid_str(ret)
                qresult = run_sql(self.conn, GET_VALID_DATA_TABLES_SQL % oid_str)
                for schema_tbl in qresult:
                    valid_tables.add((schema_tbl[0], schema_tbl[1]))
            return valid_tables

        return ret

    def _get_dirty_data_tables(self, heap_tables_set, curr_ao_state, curr_last_op, prev_ao_state, prev_last_op):
        """
        Return the union of all tables whose statistics are considered stale:
        - every heap table passed in (they are always treated as dirty)
        - ao tables whose current state differs from the previous state
        - tables whose last-operation entries differ from the previous ones
        """
        logger.debug("getting dirty data tables...")
        ao_changed = self._get_dirty_ao_state_tables(curr_ao_state, prev_ao_state)
        lastop_changed = self._get_dirty_lastop_tables(curr_last_op, prev_last_op)
        stale = heap_tables_set | ao_changed
        return stale | lastop_changed

    def _get_ao_state(self, input_tables_set):
        """
        Query the ao table info for the requested tables and turn it into partition
        state tuples via get_partition_state_tuples (looked up under 'pg_aoseg').
        """
        logger.debug("getting ao state...")
        query = GET_REQUESTED_AO_DATA_TABLE_INFO_SQL % get_oid_str(input_tables_set)
        ao_rows = run_sql(self.conn, query)
        return get_partition_state_tuples(self.pg_port, self.dbname, 'pg_aoseg', ao_rows)

    def _get_lastop_state(self, input_tables_set):
        """
        Fetch the last-operation rows for the requested tables.

        Each row is returned as a 6-tuple with the third field converted to a string.
        Presumably the layout is (schema, table, oid, action, subtype, timestamp), matching
        the last-operation state file -- verify against GET_REQUESTED_LAST_OP_INFO_SQL.
        """
        logger.debug("getting last operation states...")
        rows = run_sql(self.conn, GET_REQUESTED_LAST_OP_INFO_SQL % get_oid_str(input_tables_set))
        return [(r[0], r[1], str(r[2]), r[3], r[4], r[5]) for r in rows]

    def _write_back(self, curr_ao_state, curr_last_op, prev_ao_state, prev_last_op, heap_partitions,
                    input_col_dict, prev_col_dict, root_partition_col_dict, is_full, dirty_partitions, target_list):
        """
        Persist the outcome of this run into a fresh timestamped state directory.

        The previous ao / last-operation / column state is updated with the current values
        of every successfully analyzed table (heap partitions and root partitions excluded),
        and the merged state is written out together with a human-readable report.
        Finally stale state directories are cleaned up.

        Note that prev_col_dict is updated in place.
        """

        current_time = generate_timestamp() # timestamp used for output directory
        validate_dir("%s/%s/%s/%s" % (self.master_datadir, self.analyze_dir, self.dbname, current_time))

        # index all four state lists by (schema, table)
        curr_ao_state_dict = create_ao_state_dict(curr_ao_state)
        curr_last_op_dict = create_last_op_dict(curr_last_op)
        prev_ao_state_dict = create_ao_state_dict(prev_ao_state)
        prev_last_op_dict = create_last_op_dict(prev_last_op)

        # only tables that were analyzed successfully have their saved state refreshed;
        # heap partitions and root partitions are left out of the state files
        for schema_table in (x for x in self.success_list if
                             x not in heap_partitions and x not in root_partition_col_dict):
            # update modcount for tables that are successfully analyzed
            if schema_table in curr_ao_state_dict:
                new_modcount = curr_ao_state_dict[schema_table]
                prev_ao_state_dict[schema_table] = new_modcount

            # update last op for tables that are successfully analyzed
            if schema_table in curr_last_op_dict:
                last_op_info = curr_last_op_dict[schema_table]  # {'CREATE':'<entry>', 'ALTER':'<entry>', ...}
                prev_last_op_dict[schema_table] = last_op_info

            # update column dict: replace the saved column set when this was a full run, the table
            # was dirty, the table is new, or all columns ('-1') were analyzed; otherwise add the
            # newly analyzed columns to the ones recorded previously
            if is_full or schema_table in dirty_partitions or schema_table not in prev_col_dict or '-1' in \
                    input_col_dict[schema_table]:
                prev_col_dict[schema_table] = input_col_dict[schema_table]
            else:
                prev_col_dict[schema_table] = prev_col_dict[schema_table] | input_col_dict[schema_table]

        # flatten the merged dictionaries back into state file entries
        ao_state_output = construct_entries_from_dict_aostate(prev_ao_state_dict)
        last_op_output = construct_entries_from_dict_lastop(prev_last_op_dict)
        col_state_output = construct_entries_from_dict_colstate(prev_col_dict)

        # each state file is only written when there is something to record
        if len(ao_state_output) > 0:
            ao_state_filename = generate_statefile_name('ao', self.master_datadir, self.analyze_dir, self.dbname,
                                                        current_time)
            logger.info("Writing ao state file %s" % ao_state_filename)
            write_lines_to_file(ao_state_filename, ao_state_output)
            logger.debug("Verifying ao state file ...")
            verify_lines_in_file(ao_state_filename, ao_state_output)

        if len(last_op_output) > 0:
            last_operation_filename = generate_statefile_name('lastop', self.master_datadir, self.analyze_dir,
                                                              self.dbname, current_time)
            logger.info("Writing last operation file %s" % last_operation_filename)

            # last operation entries are 6-field sequences; serialize them as comma-separated lines
            lines_to_write = map((lambda x: '%s,%s,%s,%s,%s,%s' % (x[0], x[1], x[2], x[3], x[4], x[5])), last_op_output)
            write_lines_to_file(last_operation_filename, lines_to_write)
            logger.debug("Verifying last operation file ...")
            verify_lines_in_file(last_operation_filename, lines_to_write)

        # col_state_output holds one entry per key of prev_col_dict
        if len(prev_col_dict) > 0:
            col_state_filename = generate_statefile_name('col', self.master_datadir, self.analyze_dir, self.dbname,
                                                         current_time)
            logger.info("Writing column state file %s" % col_state_filename)
            write_lines_to_file(col_state_filename, col_state_output)
            logger.debug("Verifying column state ...")
            verify_lines_in_file(col_state_filename, col_state_output)

        # the report file is always written, even when no state file was
        report_filename = generate_statefile_name('report', self.master_datadir, self.analyze_dir, self.dbname,
                                                  current_time)
        logger.info("Writing report file %s" % report_filename)
        with open(report_filename, 'w') as fp:
            # header line: date:hh:mm:ss host:user:command line
            # NOTE(review): 'unix' is not imported explicitly in this file; presumably it comes
            # from the gppylib.mainUtils star-import -- confirm
            fp.write("%s:%s:%s:%s %s:%s:analyzedb %s\n\n" % (
                current_time[:8], current_time[8:10], current_time[10:12], current_time[12:14],
                unix.getLocalHostname(), unix.getUserName(), ' '.join(sys.argv[1:])))
            fp.write("Tables or partitions to analyze:\n---------------------------------------\n")
            for target in target_list:
                fp.write("%s\n" % target.strip())
            fp.write("\n\nTables or partitions successfully analyzed:\n--------------------------------------------\n")
            for schema_tbl in self.success_list:
                fp.write("%s.%s\n" % (escape_identifier(schema_tbl[0]), escape_identifier(schema_tbl[1])))
            fp.write("\n%d out of %d tables are analyzed.\n" % (len(self.success_list), len(target_list)))
            if len(target_list) == len(self.success_list):
                fp.write("\nanalyzedb finished successfully.\n")
            # runs while the report file is still open
            self._clean_stale_directories(current_time)

    def _get_dirty_lastop_tables(self, curr_last_op, prev_last_op):
        """
        Return the tables whose current last-operation entries are new or differ from
        the entries recorded by the previous run.
        """
        return compare_metadata(get_pgstatlastoperation_dict(prev_last_op), curr_last_op)

    def _get_dirty_ao_state_tables(self, curr_ao_state, prev_ao_state):
        """
        Index the previous and the current ao state by (schema, table) and hand both
        dictionaries to compare_dict, whose result is returned unchanged.
        """
        return compare_dict(create_ao_state_dict(prev_ao_state), create_ao_state_dict(curr_ao_state))

    def _parse_column(self, col_dict, orig_tablename, schema, table, include_cols, exclude_cols, is_root_partition):
        """
        Given a list of included or excluded columns of a table, populate the column dictionary.
        If the table is partitioned, expand it into all leaf partitions.
        If both include_cols and exclude_cols are empty, use '-1' as the value indicating 'all columns'.
        """
        included_column_set = set()
        if include_cols is not None:
            validate_columns(self.conn, schema, table, include_cols)
        elif exclude_cols is not None:
            validate_columns(self.conn, schema, table, exclude_cols)
            included_column_set = get_include_cols_from_exclude(self.conn, schema, table, exclude_cols)
            if len(included_column_set) == 0:
                raise Exception("All columns have been excluded from table %s" % orig_tablename)
        if is_root_partition:
            logger.debug("expanding partition tables...")
            tbl_parts = self._expand_partition_tables(schema, table)
        else:
            tbl_parts = [(schema, table)]

        for schema_tbl in tbl_parts:
            if include_cols is not None:
                col_dict[schema_tbl] = set(include_cols)
            elif exclude_cols is not None:
                col_dict[schema_tbl] = included_column_set
            else:  # all columns
                col_dict[schema_tbl] = set(['-1'])

    def _update_input_col_dict_with_column_increments(self, schema_table, input_col_dict, prev_col_dict):
        if schema_table in prev_col_dict:
            # since expanding the default '-1' to all column name is expensive, we avoid this as much as possible
            if '-1' not in input_col_dict[schema_table] or '-1' not in prev_col_dict[schema_table]:
                input_col_set = self._expand_columns(input_col_dict, schema_table)
                prev_col_set = self._expand_columns(prev_col_dict, schema_table)
                pending_cols = input_col_set - prev_col_set  # set difference
                input_col_dict[schema_table] = pending_cols
            else:  # both previous and current runs are without column specification
                input_col_dict[schema_table] = set()

    def _get_tablename_with_cols(self, schema, table, col_dict):
        """
        Build the escaped '<schema>.<table>' name, followed by a parenthesized, sorted,
        comma-separated column list unless the table is marked '-1' (all columns).
        """
        qualified = '%s.%s' % (escape_identifier(schema), escape_identifier(table))
        cols = col_dict[(schema, table)]
        if '-1' in cols:
            return qualified
        escaped_cols = sorted([escape_identifier(col) for col in cols])
        return qualified + '(' + ','.join(escaped_cols) + ')'

    def _expand_partition_tables(self, schema, parent):
        """
        Return the (schema, table) tuples produced by GET_LEAF_PARTITIONS_SQL for the given
        table, or a single-element list holding the table itself when the query returns nothing.
        """
        # NOTE(review): schema and parent are interpolated into the query as-is; confirm that
        # the query template (defined elsewhere) is safe for names containing quotes
        rows = run_sql(self.conn, GET_LEAF_PARTITIONS_SQL % (schema, parent, schema, parent))
        if not rows:
            return [(schema, parent)]
        return [(row[0], row[1]) for row in rows]

    def _get_root_partition_col_dict(self, candidates, input_col_dict):
        """
        Work out which root partitions need their stats refreshed and with which columns.

        A single query maps the candidates that are leaf partitions to their root partition.
        The columns recorded for a root are the set union of the columns requested for all of
        its leaf partitions among the candidates. Candidates that do not show up in the
        mapping are ignored.

        :return: dict mapping (root schema, root table) to a set of column names
        """
        logger.debug("getting mapping between leaf and root partition tables...")
        mapping_rows = run_sql(self.conn, GET_LEAF_ROOT_MAPPING_SQL % get_oid_str(candidates))
        # leaf (schema, table) -> root (schema, table), used for refreshing root stats
        leaf_to_root = dict(((row[0], row[1]), (row[2], row[3])) for row in mapping_rows)

        root_cols = {}
        for can in candidates:
            if can not in leaf_to_root:  # not a leaf partition
                continue
            root = leaf_to_root[can]
            leaf_cols = input_col_dict[can].copy()
            if root in root_cols:
                root_cols[root] |= leaf_cols
            else:
                root_cols[root] = leaf_cols
        ## TODO: do we need column expansion here?
        return root_cols

    def _get_ordered_candidates(self, candidates, root_partition_col_dict):
        """
        Return all candidates plus all root partitions as a list of (schema, table) tuples,
        in the order produced by ORDER_CANDIDATES_BY_OID_SQL.

        The intent is a descending order of OIDs, which has two benefits:
        1. The root partition will be analyzed right after the leaves
        2. The leaf partitions (if range partitioned, especially by date) will be ordered in
           descending order of the partition key, so that newer partitions can be analyzed first.
        """
        all_tables = list(candidates) + list(root_partition_col_dict.keys())
        rows = run_sql(self.conn, ORDER_CANDIDATES_BY_OID_SQL % get_oid_str(all_tables))
        return [(row[0], row[1]) for row in rows]

    def _expand_columns(self, col_dict, schema_table):
        if '-1' in col_dict[schema_table]:
            cols = run_sql(self.conn, GET_COLUMN_NAMES_SQL % get_oid_str([schema_table]))
            return set([x[0] for x in cols])
        else:
            return col_dict[schema_table]

    def ensure_semaphore_file_exists(self):
        """
        Make sure the per-database state directory and the lock file inside it exist.

        :return: the path of the lock file
        """
        db_directory = "%s/%s/%s" % (self.master_datadir, self.analyze_dir, self.dbname)
        validate_dir(db_directory)
        lock_file_path = os.path.join(db_directory, WRITE_LOCK_FILE_NAME)

        if os.path.exists(lock_file_path):
            return lock_file_path

        with open(lock_file_path, 'w') as lock_file:
            lock_file.write("semaphore for analyzedb")
        return lock_file_path

    def _write_report(self, curr_ao_state, curr_last_op, heap_partitions, input_col_dict,
                      root_partition_col_dict, dirty_partitions, target_list):
        """
        Write the state files and the report of this run while holding an exclusive
        flock on the lock file, so that concurrent runs write their results one at a time.
        """
        semaphore_path = self.ensure_semaphore_file_exists()

        with open(semaphore_path, 'r') as semaphore:
            logger.info("about to request exclusive lock on '%s' for analyzedb, to be able to write results..." % semaphore_path)
            fcntl.flock(semaphore, fcntl.LOCK_EX)  # will block until available
            logger.info("acquired analyzedb output lock, proceeding...")
            try:
                # a concurrent run may have finished in the meantime, so re-read
                # what counts as the "last/previous" run now that we hold the lock
                (last_analyze_timestamp, prev_ao_state,
                 prev_col_dict, prev_last_op) = self.read_last_analyzedb_output()

                # special case for two runs that end in the same second:
                # wait so that this run gets a timestamp of its own
                if generate_timestamp() == last_analyze_timestamp:
                    time.sleep(2)

                self._write_back(curr_ao_state, curr_last_op, prev_ao_state, prev_last_op, heap_partitions,
                                 input_col_dict, prev_col_dict, root_partition_col_dict, self.full_analyze,
                                 dirty_partitions, target_list)
            finally:
                fcntl.flock(semaphore, fcntl.LOCK_UN)

    def _clean_stale_directories(self, current_time_str):
        """
        Delete saved analyze directories that are too old.

        The saved directories are visited newest first (that is the order get_analyze_dirs
        returns them in). A directory is removed when it is more than
        REPORTS_ARE_STALE_AFTER_N_DAYS days older than the current analyze timestamp, but only
        after at least NUM_REPORTS_TO_SAVE directories older than the current run have been
        seen, so that many are left in place even if they are older than the threshold.

        Directories whose name is not a valid timestamp are skipped. A directory that cannot
        be deleted is reported with a warning and does not abort the cleanup.

        :param current_time_str: timestamp of the current run, of the form yyyymmddhhmmss
        """
        timestamp_format = '%Y%m%d%H%M%S'
        current_analyze_time = datetime.strptime(current_time_str, timestamp_format)
        stale_after = timedelta(days=REPORTS_ARE_STALE_AFTER_N_DAYS)
        num_previous_statefiles = 0

        for saved_analyze_dir in get_analyze_dirs(self.master_datadir, self.analyze_dir, self.dbname):
            # Only the timestamp parsing is guarded here: anything else going wrong must not
            # be misreported as a naming problem.
            try:
                time_of_saved_analyze = datetime.strptime(os.path.basename(saved_analyze_dir), timestamp_format)
            except ValueError:
                logger.info("Ignoring log directory that does not conform to naming convention: %s" % saved_analyze_dir)
                continue

            time_diff = current_analyze_time - time_of_saved_analyze
            if num_previous_statefiles >= NUM_REPORTS_TO_SAVE and time_diff > stale_after:
                # this directory is older than the threshold and enough newer ones are kept, remove it
                if os.path.exists(saved_analyze_dir):
                    try:
                        shutil.rmtree(saved_analyze_dir)
                    except EnvironmentError as rm_err:
                        # cleanup is best effort; keep going with the remaining directories
                        logger.warning("Could not delete archived log directory %s: %s" % (saved_analyze_dir, rm_err))
                    else:
                        logger.info("Deleted archived log directory %s" % saved_analyze_dir)
            if time_diff > timedelta(seconds=0):
                num_previous_statefiles += 1


def create_psql_command(dbname, query):
    """
    Create a Command object that executes a query using psql.

    Both the database name and the query are shell-quoted. The query text itself
    is used as the name of the Command.
    """
    quoted_dbname = pipes.quote(dbname)
    quoted_query = pipes.quote(query)
    return Command(query, "psql %s -c %s" % (quoted_dbname, quoted_query))


def run_sql(conn, query):
    """
    Execute a query on the given connection and return all result rows.

    :param conn: connection handed to dbconn.execSQL
    :param query: SQL text to execute
    :return: whatever cursor.fetchall() returns for the query
    :raises ExceptionNoStackTraceNeeded: if executing the query or fetching the rows fails;
            the message is the text of the underlying error
    """
    cursor = None
    try:
        cursor = dbconn.execSQL(conn, query)
        return cursor.fetchall()
    except Exception as db_err:
        raise ExceptionNoStackTraceNeeded("%s" % db_err.__str__())  # .split('\n')[0])
    finally:
        # close the cursor on the failure path as well, not only after a successful fetch
        if cursor is not None:
            cursor.close()


def generate_timestamp():
    """Return the current local time formatted as yyyymmddhhmmss."""
    return datetime.now().strftime("%Y%m%d%H%M%S")


def get_oid_str(table_list):
    """
    Turn (schema, table) tuples into a comma-separated list of SQL expressions of the
    form to_regclass('schema.table'), which can be embedded safely in an SQL string.

    The escaping is a bit tricky: the schema and table name need to be double-quoted,
    and the whole string needs to be in single quotes; regclass_schema_tbl takes care of both.
    """
    return ','.join([regclass_schema_tbl(entry[0], entry[1]) for entry in table_list])


def regclass_schema_tbl(schema, tbl):
    """
    Return the SQL expression to_regclass('<schema>.<table>') for the given table.

    to_regclass is used instead of ::regclass because it returns NULL instead of an
    error if the table does not exist.
    """
    qualified_name = '.'.join([escape_identifier(schema), escape_identifier(tbl)])
    escaped_literal = pg.escape_string(qualified_name)
    return "to_regclass('%s')" % escaped_literal


# Quote a string, if necessary, so that the result is suitable for embedding in SQL
# as an identifier. Analogous to libpq's PQescapeIdentifier.
def escape_identifier(str):
    """
    Return the identifier unchanged if it is a simple name, otherwise double-quoted.

    A simple name starts with a lower-case ASCII letter and contains only lower-case
    ASCII letters, digits and underscores. Anything else is wrapped in double quotes,
    with embedded double quotes escaped by doubling them.

    :param str: the identifier (the parameter name shadows the builtin; it is kept
                unchanged for the sake of existing callers)
    :return: the identifier in a form usable in SQL text
    """
    # Does the string need quoting? The pattern is anchored with \Z rather than '$':
    # '$' also matches right before a trailing newline, which would let a name that ends
    # in a newline through unquoted.
    if re.match(r'[a-z][a-z0-9_]*\Z', str):
        return str

    # Otherwise we have to quote it. Any double-quotes in the string need to be escaped
    # by doubling them.
    return '"' + str.replace('"', '""') + '"'


def get_heap_tables_set(conn, input_tables_set):
    """
    Return the subset of the requested tables that GET_REQUESTED_NON_AO_TABLES_SQL
    reports, as a set of (schema, table) tuples.
    """
    logger.debug("getting heap tables...")
    query = GET_REQUESTED_NON_AO_TABLES_SQL % get_oid_str(input_tables_set)
    return set((row[0], row[1]) for row in run_sql(conn, query))


def get_latest_analyze_timestamp(master_datadir, statefile_dir, dbname):
    """
    Find the timestamp of the most recent analyze run that left a report file behind.

    The analyze directories are examined in the order get_analyze_dirs returns them.
    Directories that are empty or contain no report file are skipped with a warning.

    :return: the timestamp part of the first report file name found, or None if there is none
    """
    report_pattern = 'analyze_[0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9]_report'

    for analyze_dir in get_analyze_dirs(master_datadir, statefile_dir, dbname):
        files = sorted(os.listdir(analyze_dir))
        if not files:
            logger.warn('Analyze state file directory %s is empty. Ignoring this directory...' % analyze_dir)
            continue

        report_files = fnmatch.filter(files, report_pattern)
        if not report_files:
            logger.warn(
                'No analyze report files found in analyze directory %s. Ignoring this directory...' % analyze_dir)
            continue

        # file name is analyze_<timestamp>_report
        return report_files[0].split('_')[1]

    return None


def get_prev_ao_state(timestamp, master_datadir, analyze_dir, dbname):
    """
    Read the ao state file of the run identified by timestamp.

    :return: a list with one list of comma-separated fields per line (intended to be
             schemaname, tablename, modcount), or an empty list if the file does not exist
    """
    logger.debug("getting previous ao state...")
    state_path = generate_statefile_name('ao', master_datadir, analyze_dir, dbname, timestamp)
    if not os.path.isfile(state_path):
        return []
    # Each line is a comma-separated string. XXX: This file format cannot deal with
    # names with commas.
    return [line.split(',') for line in get_lines_from_file(state_path)]


def get_prev_last_op(timestamp, master_datadir, analyze_dir, dbname):
    """
    Read the last-operation state file of the run identified by timestamp.

    :return: a list of 6-tuples of strings, one per line, like
             ('public', 'ao_tab', '67468', 'ALTER', 'ADD COLUMN', '2014-10-15 14:49:27.658777-07');
             an empty list if the file does not exist
    """
    logger.debug("getting previous last operation...")
    lastop_path = generate_statefile_name('lastop', master_datadir, analyze_dir, dbname, timestamp)
    if os.path.isfile(lastop_path):
        lines = get_lines_from_file(lastop_path)
    else:
        lines = []

    # Each line is a comma-separated string. XXX: This file format cannot deal with names
    # with commas.
    entries = []
    for line in lines:
        fields = line.split(',')
        entries.append((fields[0], fields[1], fields[2], fields[3], fields[4], fields[5]))
    return entries


def get_prev_col_state(timestamp, master_datadir, analyze_dir, dbname):
    """
    Read the column state file of the run identified by timestamp.

    Every line has the form <schema>,<table>,<col1>,<col2>,...

    :return: dict mapping (schema, table) to the set of column names on that line;
             an empty dict if the file does not exist
    """
    logger.debug("getting previous column states...")
    col_state_path = generate_statefile_name('col', master_datadir, analyze_dir, dbname, timestamp)
    if not os.path.isfile(col_state_path):
        return {}

    col_state = {}
    for line in get_lines_from_file(col_state_path):
        fields = [str.strip(field) for field in line.strip().split(',')]
        col_state[(fields[0], fields[1])] = set(fields[2:])
    return col_state


def create_ao_state_dict(ao_state_entries):
    """
    Index ao state entries by table.

    :param ao_state_entries: sequence of entries whose first three fields are
                             schema, table and the state value (modcount)
    :return: dict mapping (schema, table) to the state value; for a table that
             occurs more than once the last entry wins
    """
    return dict(((entry[0], entry[1]), entry[2]) for entry in ao_state_entries)


def create_last_op_dict(last_op_entries):
    """
    Index last-operation entries by table and by action.

    :param last_op_entries: sequence of entries with schema and table in fields 0 and 1
                            and the action name in field 3
    :return: dict of the form {(schema, table): {action: entry}}; for a repeated
             (table, action) pair the last entry wins
    """
    by_table = {}
    for entry in last_op_entries:
        table_key = (entry[0], entry[1])
        by_table.setdefault(table_key, {})[entry[3]] = entry
    return by_table


def construct_entries_from_dict_aostate(ao_state_dict):
    """
    Turn an ao state dictionary into state file lines.

    :param ao_state_dict: dict mapping (schema, table) to the state value
    :return: list of strings of the form '<schema>,<table>,<value>', one per dictionary entry
    """
    # items() rather than iteritems(): same result, and not tied to Python 2
    return ["%s,%s,%s" % (key[0], key[1], item) for key, item in ao_state_dict.items()]


def construct_entries_from_dict_lastop(last_op_dict):
    """
    Flatten a last-operation dictionary back into a list of entries.

    :param last_op_dict: dict of the form {(schema, table): {action: entry}}
    :return: list of all the entries, one per (table, action) pair
    """
    ret = []
    # values() rather than itervalues(): same result, and not tied to Python 2
    for ops_by_action in last_op_dict.values():
        ret.extend(ops_by_action.values())
    return ret


def construct_entries_from_dict_colstate(prev_col_dict):
    """
    Turn a column state dictionary into state file lines.

    :param prev_col_dict: dict mapping (schema, table) to a set of column names
    :return: list of strings of the form '<schema>,<table>,<col1>,<col2>,...',
             e.g. 'public,foo,a,b,c', one per dictionary entry
    """
    ret = []
    # items() rather than iteritems(): same result, and not tied to Python 2
    for schema_table, col_set in prev_col_dict.items():
        # the columns are sorted so that the same column set always yields the same line
        cols = ','.join(sorted(col_set))
        ret.append("%s,%s,%s" % (schema_table[0], schema_table[1], cols))
    return ret


def compare_metadata(old_pgstatlastoperation, cur_pgstatlastoperation):
    """
    Find the tables whose last-operation entries are new or have changed.

    :param old_pgstatlastoperation: dict mapping (schema, table, action name) to the
                                    previously recorded entry
    :param cur_pgstatlastoperation: iterable of current entries, with schema and table in
                                    fields 0 and 1 and the action name in field 3
    :return: set of (schema, table) tuples with at least one new or differing entry
    """
    changed_tables = set()
    for operation in cur_pgstatlastoperation:
        schema, table, action = operation[0], operation[1], operation[3]
        # schema, table and action name together identify a specific operation
        op_key = (schema, table, action)
        is_new = op_key not in old_pgstatlastoperation
        if is_new or old_pgstatlastoperation[op_key] != operation:
            changed_tables.add((schema, table))
    return changed_tables


def get_pgstatlastoperation_dict(last_operations):
    """
    Index last-operation entries by (schema, table, action name), i.e. by fields
    0, 1 and 3 of each entry. For a repeated key the last entry wins.
    """
    return dict(((op[0], op[1], op[3]), op) for op in last_operations)


def generate_statefile_name(type_str, master_data_dir, analyze_dir, dbname, timestamp):
    """
    Build the path of a state file inside the timestamped directory of a run.

    :param type_str: one of 'lastop', 'ao', 'col' or 'report'
    :return: <master_data_dir>/<analyze_dir>/<dbname>/<timestamp>/analyze_<timestamp>_<suffix>
    :raises Exception: if type_str is not one of the known types
    """
    name_templates = (
        ('lastop', "%s/analyze_%s_last_operation"),
        ('ao', "%s/analyze_%s_ao_state_file"),
        ('col', "%s/analyze_%s_col_state_file"),
        ('report', "%s/analyze_%s_report"),
    )
    run_dir = "%s/%s/%s/%s" % (master_data_dir, analyze_dir, dbname, timestamp)
    for known_type, template in name_templates:
        if type_str == known_type:
            return template % (run_dir, timestamp)
    raise Exception("Invalid type string for generating state file name")


def get_analyze_dirs(master_datadir, statefile_dir, dbname):
    """
    List the saved analyze directories of a database, newest first.

    Only sub-directories whose name consists of exactly 14 digits are considered;
    they are ordered by the numeric value of that name, in descending order.

    :return: list of directory paths; empty if the database has no state directory yet
    """
    analyze_path = os.path.join(master_datadir, statefile_dir, dbname)
    if not os.path.isdir(analyze_path):
        return []

    timestamp_names = fnmatch.filter(os.listdir(analyze_path),
                                     '[0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9]')

    candidate_paths = [os.path.join(analyze_path, name) for name in timestamp_names]
    dirnames = [pth for pth in candidate_paths if os.path.isdir(pth)]
    dirnames.sort(key=lambda pth: int(os.path.basename(pth)), reverse=True)
    return dirnames


def validate_dir(path):
    """
    Make sure a directory exists and is writable, creating it if necessary.

    Writability is probed by creating a temporary file inside the directory.

    :raises AnalyzeDirCreateFailed: if the directory is missing and creating it fails with an OSError
    :raises AnalyzeDirNotWritable: if no temporary file can be created in the directory
    """
    if CheckDir(path).run():
        logger.info("Directory %s exists" % path)
    else:
        try:
            MakeDir(path).run()
        except OSError:
            logger.exception("Could not create directory %s" % path)
            raise AnalyzeDirCreateFailed()
        logger.info("Created %s" % path)

    try:
        with tempfile.TemporaryFile(dir=path):
            pass
    except Exception:
        logger.exception("Cannot write to %s" % path)
        raise AnalyzeDirNotWritable()


def parse_tables_from_file(conn, include_file):
    """
    Parse the list of tables from the config file.

    Each non-blank line is expected to look like
        <schema>.<table> [-i|-x <col1>,<col2>,...]
    Only the table name (the first token) is returned here.

    :param conn: not used by this function
    :param include_file: path of the config file
    :return: list of the table names, in file order
    :raises ExceptionNoStackTraceNeeded: if a line has more than three tokens, a table name
            has no schema qualification, or a table name occurs more than once
    """
    # The file is closed even when a validation error is raised below. splitlines()
    # recognises \n, \r\n and \r line endings, so the deprecated 'U' open mode is not needed.
    with open(include_file, 'r') as in_file:
        lines = in_file.read().splitlines()

    tables = []
    seen = set()  # constant-time duplicate detection
    # line_no counts every line of the file, blank ones included, so that the numbers
    # in the error messages match what an editor shows
    for line_no, line in enumerate(lines, 1):
        toks = line.strip().split()
        if not toks:  # empty line
            continue
        if len(toks) > 3:  # we are expecting <schema>.<table> --in(ex)clude_column <col1>,<col2>,...
            raise ExceptionNoStackTraceNeeded(
                "Wrong input arguments in line %d of config file. Please check usage." % line_no)
        table_name = toks[0]
        if '.' not in table_name:
            raise ExceptionNoStackTraceNeeded("No schema name supplied for table %s" % table_name)
        if table_name in seen:
            raise ExceptionNoStackTraceNeeded("Duplicate table name in line %d of config file." % line_no)

        seen.add(table_name)
        tables.append(table_name)

    return tables


def validate_tables(conn, tablenames):
    """
    Look up a user-supplied list of table names (from command line or config file)
    in the database.

    tablenames needs to be a list of 'schema.table's. As a side-effect of the lookup a
    canonicalized form of each table name is obtained.

    :return: dict mapping the third column of each row returned by VALIDATE_TABLE_NAMES_SQL
             (presumably the name as originally supplied) to the (schema, table) tuple
             formed from the first two columns
    """
    # since the number of target entries cannot be greater than 1664 in GPDB/HAWQ,
    # we validate the tables in batches of 1500
    # XXX: We use a VALUES list now. What's the maximum size of that?
    batch_size = 1500

    canonical_dict = {}
    for start in range(0, len(tablenames), batch_size):
        batch = tablenames[start:start + batch_size]
        values_list = ','.join(["('%s')" % pg.escape_string(name) for name in batch])
        for row in run_sql(conn, VALIDATE_TABLE_NAMES_SQL % values_list):
            canonical_dict[row[2]] = (row[0], row[1])

    return canonical_dict


def get_include_cols_from_exclude(conn, schema, table, exclude_cols):
    """
    Given a list of excluded columns of a table, get the set of included columns.

    :return: set holding the first column of every row returned by
             GET_INCLUDED_COLUMNS_FROM_EXCLUDE_SQL
    """
    excluded_literals = ["'%s'" % pg.escape_string(col) for col in exclude_cols]
    quoted_exclude_cols = ','.join(excluded_literals)
    table_regclass = regclass_schema_tbl(schema, table)

    rows = run_sql(conn, GET_INCLUDED_COLUMNS_FROM_EXCLUDE_SQL % (table_regclass, quoted_exclude_cols))
    return set([row[0] for row in rows])


def validate_columns(conn, schema, table, column_list):
    """
    Check whether all column names in a list are valid for a table.

    The number of matching columns reported by the database is compared with the
    length of the list. An empty list is accepted without querying the database.

    :raises Exception: if the two counts differ
    """
    if len(column_list) < 1:
        return

    table_regclass = regclass_schema_tbl(schema, table)
    quoted_cols = ','.join(["'%s'" % pg.escape_string(col) for col in column_list])
    valid_col_count = dbconn.execSQLForSingleton(conn, VALIDATE_COLUMN_NAMES_SQL % (table_regclass, quoted_cols))

    if int(valid_col_count) == len(column_list):
        return
    raise Exception(
        "Invalid input columns for table %s.%s." % (escape_identifier(schema), escape_identifier(table)))


def create_parser():
    """
    Build the command line parser of analyzedb.

    The default -h/--help option of OptionParser is replaced by one that also accepts -?.

    :return: the configured OptionParser
    """
    description = (
        "Analyze a database incrementally. 'Incremental' means if a table or partition has not been modified by "
        "DML or DDL commands since the last analyzedb run, it will be automatically skipped since its statistics "
        "must be up to date. Some restrictions apply:  "
        "1. The incremental semantics only applies to append-only tables or partitions. All heap tables are regarded "
        "as having stale stats every time analyzedb is run. This is because we use AO metadata to check for DML or "
        "DDL events, which is not available to heap tables.  "
        "2. Views, indices and external tables are automatically skipped.  "
        "3. Table names or schema names containing comma or period is not supported yet."
    )

    # (option strings, keyword arguments) in the order the options are registered
    option_specs = [
        (('-d',), dict(dest='dbname', metavar="<database name>",
                       help="Database name. Required.")),
        (('-s',), dict(dest='schema', metavar="<schema name>",
                       help="Specify a schema to analyze. All tables in the schema will be analyzed.")),
        (('-t',), dict(dest='single_table', metavar="<schema name>.<table name>",
                       help="Analyze a single table. Table name needs to be qualified with schema name.")),
        (('-i',), dict(type='string', dest='include_cols', metavar="<column1>,<column2>,...",
                       help="Columns to include to be analyzed, separated by comma. All columns will be analyzed if not specified.")),
        (('-x',), dict(type='string', dest='exclude_cols', metavar="<column1>,<column2>,...",
                       help="Columns to exclude to be analyzed, separated by comma. All columns will be analyzed if not specified.")),
        (('-f', '--file'), dict(dest='config_file', metavar="<config_file>",
                                help="Config file that includes a list of tables to be analyzed. "
                                     "Table names must be qualified with schema name. Optionally a list of columns (separated by "
                                     "comma) can be specified using -i or -x.")),
        (('-l', '--list'), dict(action='store_true', dest='dry_run', default=False,
                                help="List the tables to be analyzed without actually running analyze (dry run).")),
        (('-p',), dict(type='int', dest='parallel_level', default=5, metavar="<parallel level>",
                       help="Parallel level, i.e. the number of tables to be analyzed in parallel. Valid numbers are between 1 and 10. Default value is 5.")),
        (('--skip_root_stats',), dict(action='store_false', dest='rootstats', default=True,
                                      help="This option is no longer used. Please remove it from your scripts.")),
        (('--skip_orca_root_stats',), dict(action='store_false', dest='orca_rootstats', default=True,
                                           help="Suppress generation of root partition stats for ORCA.")),
        (('--gen_profile_only',), dict(action='store_true', dest='gen_profile_only', default=False,
                                       help="Create cached state files to indicate specified table or all tables have been analyzed.")),
        (('--full',), dict(action='store_true', dest='full_analyze', default=False,
                           help="Analyze without using incremental. All tables requested by the user will be analyzed.")),
        (('--clean_last',), dict(action='store_true', dest='clean_last', default=False,
                                 help="Clean the state files generated by last analyzedb run. All other options except -d and -a will be ignored.")),
        (('--clean_all',), dict(action='store_true', dest='clean_all', default=False,
                                help="Clean all the state files generated by analyzedb. All other options except -d and -a will be ignored.")),
        (('-h', '-?', '--help'), dict(action='help',
                                      help='Show this help message and exit.')),
        (('-v', '--verbose'), dict(action='store_true', dest='verbose', help='Print debug messages.')),
        (('-a',), dict(action='store_true', dest='silent', default=False,
                       help="Quiet mode. Do not prompt for user confirmation.")),
    ]

    parser = OptionParser(version='%prog version 1.0', description=description)
    parser.set_usage('%prog [options] ')
    # drop the built-in help option; it is re-added from the table above with the extra '-?' flag
    parser.remove_option('-h')
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)
    return parser


class AnalyzeDirCreateFailed(Exception):
    """
    Error type named for a failure to create the analyze directory.

    Adds nothing on top of Exception; the places that raise or catch it
    are outside this section of the file.
    """


class AnalyzeDirNotWritable(Exception):
    """
    Error type named for an analyze directory that cannot be written to.

    Adds nothing on top of Exception; the places that raise or catch it
    are outside this section of the file.
    """


class AnalyzeWorkerPool(WorkerPool):
    """
    Worker pool specialised for analyze commands.

    The constructor sets up the pool state itself (it does not call
    WorkerPool.__init__) and populates the pool with AnalyzeWorker threads.
    The threads are only created here; they are started by run().
    """

    def __init__(self, numWorkers=5, items=None):
        """
        Build the pool state, queue any initial items and create the workers.

        numWorkers -- number of AnalyzeWorker threads to create
        items      -- optional iterable of commands, each passed to addCommand()
        """
        # pool bookkeeping; must exist before commands are queued or workers are built
        self.should_stop = False
        self.daemonize = False
        self.logger = logger
        self._assigned = 0
        self.workers = []
        self.work_queue = Queue()
        self.completed_queue = Queue()

        initial_commands = [] if items is None else items
        for command in initial_commands:
            self.addCommand(command)

        # AnalyzeWorker (not the stock Worker) services this pool
        for index in range(numWorkers):
            self.workers.append(AnalyzeWorker("worker%d" % index, self))
        self.numWorkers = numWorkers

    def run(self):
        """Start every worker thread held by the pool."""
        for worker in self.workers:
            worker.start()

class AnalyzeWorker(Worker):
    """
    Worker thread used by AnalyzeWorkerPool.

    Pulls commands from its pool one at a time, runs them, logs start and
    finish (with elapsed time) and hands each command back to the pool.
    """

    def __init__(self, name, pool):
        Worker.__init__(self, name, pool)

    def _execute_current_cmd(self):
        """
        Run self.cmd, log its progress, and return it to the pool as finished.

        Any exception raised here propagates to run(), which handles it.
        """
        # strip the "set ...;" GUC statements so the log shows only the analyze itself
        display_name = re.sub(r'set .*;\s*', '', self.cmd.name)
        self.logger.info("[%s] started  %s" % (self.name, display_name))

        began = time.time()
        self.cmd.run()
        ended = time.time()

        # surface whatever the command wrote to stderr
        stderr_lines = self.cmd.get_stderr_lines()
        if len(stderr_lines) > 0:
            self.logger.warning('\n'.join(stderr_lines))

        # the "finished" line is only logged for successful commands
        if self.cmd.was_successful():
            elapsed_seconds = int(ended - began)
            self.logger.info("[%s] finished %s. Elapsed time: %d seconds." % (self.name, display_name,
                                                                              elapsed_seconds))

        self.pool.addFinishedWorkItem(self.cmd)
        self.cmd = None

    def run(self):
        """
        Service the pool until a halt command arrives.

        Each iteration fetches the next work item and then either marks it
        done without running it (None item, halt command, or pool flagged
        should_stop) or executes it. The loop exits on the halt command, or
        when fetching the next item raises TypeError.
        """
        while True:
            try:
                try:
                    self.cmd = self.pool.getNextWorkItem()
                except TypeError:
                    # misleading exception raised during interpreter shutdown
                    return

                # from here on a work item has been fetched
                if self.cmd is None:
                    self.logger.debug("[%s] got a None cmd" % self.name)
                    self.pool.markTaskDone()
                    continue

                if self.cmd is self.pool.halt_command:
                    self.logger.debug("[%s] got a halt cmd" % self.name)
                    self.pool.markTaskDone()
                    self.cmd = None
                    return

                if self.pool.should_stop:
                    # pool is stopping: consume the item without running it
                    self.logger.debug("[%s] got cmd and pool is stopped: %s" % (self.name, self.cmd))
                    self.pool.markTaskDone()
                    self.cmd = None
                    continue

                self._execute_current_cmd()

            except Exception as e:
                self.logger.exception(e)
                if self.cmd:
                    # still report the command as finished so the pool is not left waiting on it
                    self.logger.debug("[%s] finished cmd with exception: %s" % (self.name, self.cmd))
                    self.pool.addFinishedWorkItem(self.cmd)
                    self.cmd = None


# Script entry point: only runs when the file is executed directly, not when imported.
if __name__ == '__main__':
    # Replace the invoked script path with the tool's canonical name (EXECNAME)
    # before the program is started, so anything reading sys.argv[0] sees it.
    sys.argv[0] = EXECNAME
    # Hand the option-parser factory and the AnalyzeDb class to simple_main.
    # NOTE(review): simple_main (and sys) presumably come from the
    # `from gppylib.mainUtils import *` star import -- confirm.
    simple_main(create_parser, AnalyzeDb)
